from machine import *
import stats

class Trivial(Machine):
  """Common base for the simple baseline predictors in this module.

  Subclasses supply a ``name`` attribute and an ``est_rating(mid, uid)``
  method that reads from ``self.stats``. Other behaviour (e.g. the
  ``calc_rmse`` called from the script section) is presumably inherited
  from ``Machine`` -- confirm in the ``machine`` module.
  """

  def __init__(self, stats_data):
    """Remember *stats_data* so ``est_rating`` can look values up in it.

    NOTE(review): ``Machine.__init__`` is not called here; verify that
    ``Machine`` needs no initialisation of its own.
    """
    self.stats = stats_data

class User_Avg(Trivial):
  """Baseline that predicts from the rating user's statistics alone."""
  # Identifier for this predictor (presumably consumed by ``Machine``).
  name = "user_avg"

  def est_rating(self, mid, uid):
    """Return field 0 of the ``basic_user`` record for *uid*.

    *mid* is ignored. Field 0 is presumably the user's mean rating --
    confirm against ``stats.Stats``. *uid* is passed through ``int()``,
    so numeric strings are accepted as well as ints.
    """
    user_record = self.stats.basic_user[int(uid)]
    return user_record[0]

class Movie_Avg(Trivial):
  """Baseline that predicts from the rated movie's statistics alone."""
  # Identifier for this predictor (presumably consumed by ``Machine``).
  name = "movie_avg"

  def est_rating(self, mid, uid):
    """Return field 0 of the ``basic_movie`` record for *mid*.

    *uid* is ignored. Field 0 is presumably the movie's mean rating --
    confirm against ``stats.Stats``. *mid* is passed through ``int()``,
    so numeric strings are accepted as well as ints.
    """
    movie_record = self.stats.basic_movie[int(mid)]
    return movie_record[0]

class MovieUser_Avg(Trivial):
  """Baseline that blends the user's and the movie's statistics equally."""
  # Identifier for this predictor (presumably consumed by ``Machine``).
  name = "avg_avg"

  def est_rating(self, mid, uid):
    """Return the midpoint of the user's and the movie's field-0 values.

    Field 0 of ``basic_user`` / ``basic_movie`` is presumably the mean
    rating -- confirm against ``stats.Stats``. Both ids are passed through
    ``int()`` before indexing.
    """
    user_value = self.stats.basic_user[int(uid)][0]
    movie_value = self.stats.basic_movie[int(mid)][0]
    # Divide by 2.0 rather than 2: under Python 2, `/` between two ints is
    # floor division, which would truncate the midpoint whenever both
    # stored values are ints. For float inputs the result is unchanged.
    return (user_value + movie_value) / 2.0

if __name__ == "__main__":
  # Build one statistics object and read its data once; every baseline
  # below shares it.
  # NOTE(review): `tr` is not defined in the code visible here; presumably
  # it is provided by `from machine import *` -- confirm.
  stats_data = stats.Stats(tr, stats.STORE)
  stats_data.read_data()

  # (heading printed before the run, predictor class to evaluate)
  baselines = (
    ("user_avg:", User_Avg),
    ("movie_avg:", Movie_Avg),
    ("movie_user_avg:", MovieUser_Avg),
  )
  for heading, predictor_cls in baselines:
    # A single-argument print(...) produces the same output in Python 2 and 3.
    print(heading)
    # NOTE(review): the meaning of the False flag is defined by
    # Machine.calc_rmse -- check there.
    predictor_cls(stats_data).calc_rmse(False)
  